
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding.

    Cuts an image into non-overlapping patches with a strided convolution
    (kernel size == stride == ``patch_size``) and returns one
    ``embed_dim``-vector per patch.

    Args:
        img_size: expected input (H, W); an int means a square image.
        patch_size: patch (h, w); an int means square patches.
        in_c: number of input channels.
        embed_dim: number of output channels per patch token.
        norm_layer: optional callable, instantiated as ``norm_layer(embed_dim)``
            and applied to the tokens; identity when not given.
    """
    def __init__(self, img_size=224, patch_size=8, in_c=3, embed_dim=192, norm_layer=None):
        super().__init__()
        img_size = self._to_pair(img_size)
        patch_size = self._to_pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # kernel == stride, so every output position covers exactly one patch.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    @staticmethod
    def _to_pair(value):
        """Return ``value`` as a 2-tuple; a scalar is duplicated."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"expected an int or a pair, got {value!r}")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Map images ``[B, in_c, H, W]`` to tokens ``[B, num_patches, embed_dim]``.

        Raises AssertionError when (H, W) differs from ``img_size``.
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, E, H', W'] -> [B, E, H'W']
        # transpose: [B, E, H'W'] -> [B, H'W', E]  (row-major patch order)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x

class Attention(nn.Module):
    """Block-wise HGR attention over a square grid of tokens.

    The token grid (``N = side * side`` tokens) is cut into
    ``blocksize x blocksize`` windows, giving ``P = (side // blocksize) ** 2``
    windows of ``S = blocksize ** 2`` tokens each. Every query window is
    scored against every key window with an HGR-style similarity: mean
    cosine correlation of matching rows plus a term comparing the windows'
    Gram matrices. The resulting ``P x P`` matrix is resized to ``N x N`` by
    bicubic interpolation, soft-maxed, and used to mix the value vectors.

    Args:
        dim: the layer works on ``dim + 1`` channels; the input's last
            dimension must be ``dim + 1``.
        blocksize: side length of a window, in tokens.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: accepted for API compatibility; unused.
        attn_drop_ratio: dropout probability on the attention weights.
        proj_drop_ratio: dropout probability after the output projection.
    """

    def __init__(self,
                 dim,
                 blocksize=7,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.blocksize = blocksize
        self.qkv = nn.Linear((dim + 1), (dim + 1) * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear((dim + 1), (dim + 1))
        self.proj_drop = nn.Dropout(proj_drop_ratio)
        # NOTE(review): not used by forward(); kept so state_dicts that already
        # contain its parameters still load with strict=True.
        self.PatchEmbed = PatchEmbed()

    @staticmethod
    def _offdiag(dist):
        """Fold the strictly-upper triangle of ``[..., S, S]`` into ``[..., S-1, S-1]``.

        Entry ``[r, c]`` is ``dist[r, c + 1]`` for ``c >= r``; entries below
        the diagonal mirror the ones above it.
        """
        n = dist.shape[-2]
        upper = torch.triu(dist, diagonal=1)[..., 0:n - 1, 1:n]
        return upper + torch.triu(upper, diagonal=1).transpose(-2, -1)

    @staticmethod
    def _block_scores(q, k):
        """HGR-style similarity between every query window and every key window.

        Args:
            q, k: ``[B, P, S, C]`` -- P windows of S tokens with C channels.

        Returns:
            ``[B, P, P]`` where entry ``[b, p, j]`` scores ``q[b, p]`` against
            ``k[b, j]`` as ``(corr + trace) / 1.5``.
        """
        n_samples = q.shape[-2]
        f = F.normalize(q, dim=-1)
        g = F.normalize(k, dim=-1)
        # Mean cosine similarity of matching rows, for all (p, j) pairs at once.
        corr = torch.einsum('bpsc,bjsc->bpj', f, g) / n_samples

        # Gram matrices depend on one side only, so build each once (not per pair).
        f1 = F.normalize(Attention._offdiag(f @ f.transpose(-2, -1)), dim=-1)
        g1 = F.normalize(Attention._offdiag(g @ g.transpose(-2, -1)), dim=-1)
        trace = torch.einsum('bpst,bjst->bpj', f1, g1) / (n_samples - 1)

        return (corr + trace) / 1.5

    def forward(self, x):
        """Apply block-wise HGR attention.

        Args:
            x: ``[B, N, dim + 1]``. N must be a perfect square whose side is
                divisible by ``blocksize``. Tokens are assumed to be in
                row-major grid order, e.g. a flattened ``[C, H, W]`` map
                (TODO confirm for all callers).

        Returns:
            ``[B, N, dim + 1]``.

        Raises:
            ValueError: if N is not a square grid compatible with ``blocksize``.
        """
        B, N, C = x.shape
        bs = self.blocksize
        side = int(round(math.sqrt(N)))
        if N == 0 or side * side != N or side % bs != 0:
            raise ValueError(
                f"Attention needs a square token grid whose side is divisible "
                f"by blocksize={bs}; got N={N}."
            )
        nb = side // bs  # windows per grid side

        qkv = self.qkv(x)  # [B, N, 3C], channels ordered (q | k | v)

        # Split tokens into (block_row, row_in_block, block_col, col_in_block)
        # and channels into (q/k/v, C), then gather each window's tokens:
        # -> [3, B, P, S, C]
        blocks = qkv.reshape(B, nb, bs, nb, bs, 3, C).permute(5, 0, 1, 3, 2, 4, 6)
        blocks = blocks.reshape(3, B, nb * nb, bs * bs, C)
        q, k = blocks[0], blocks[1]
        # Values stay token-major so they line up with the N x N attention.
        v = qkv.reshape(B, N, 3, C)[:, :, 2]  # [B, N, C]

        attn = self._block_scores(q, k)  # [B, P, P], [query_window, key_window]

        # NOTE(review): bicubic resizing over flattened window indices does not
        # map each token exactly to its own window; confirm this is intended.
        attn = F.interpolate(attn.unsqueeze(1), size=(N, N), mode='bicubic',
                             align_corners=False).squeeze(1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v  # [B, N, C]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class HGR_atten(nn.Module):
    """Experimental HGR-based attention-score module.

    NOTE(review): this module does not run as written:
      * ``forward`` reads ``self.patch_size`` and ``self.num_heads``, which
        ``__init__`` never sets, so its first line raises AttributeError;
      * ``HGRscore3`` uses ``torch.mm``/``torch.t``, which expect 2-D
        matrices, while the ``q``/``k`` built here are higher-rank;
      * ``HGRscore3`` returns 0-dim sums, but the result is then permuted
        as a 4-D tensor.
    Treat it as a sketch and confirm the intended shapes before using it.
    """
    def __init__(self,
                 dim,
                 qk_bias=False):
        super(HGR_atten, self).__init__()
        # Fused projection producing query and key (2 * dim channels); no value.
        self.qkv = nn.Linear(dim, dim * 2, bias=qk_bias)


    def forward(self,x):
        # Group the token axis into patch_size x patch_size tiles:
        # [B, T, C] -> [B, T // patch_size**2, patch_size, patch_size, C]
        x = x.reshape(x.shape[0], x.shape[1] //self.patch_size**2,self.patch_size, self.patch_size, x.shape[2])
        B, N,H,W,C = x.shape
        # qkv -> [B, N, H, W, 2C]; reshape to (B, N, 2, H, W, heads, C // heads),
        # then permute to (2, B, heads, N, C // heads, H, W).
        # NOTE(review): the reshape appears to reinterpret the [H, W, 2C]
        # memory as [2, H, W, heads, C // heads], so the leading "2" splits
        # along H rather than separating q from k -- confirm intended layout.
        qk = self.qkv(x).reshape(B, N, 2, H, W, self.num_heads, C // self.num_heads).permute(2, 0, 5, 1, 6, 3, 4)
        q, k = qk[0], qk[1]

        # HGRscore3 returns (score, corr, tra); only the score is kept.
        attn,__,__ = HGRscore3(q, k)
        attn = attn.permute(0,2,1,3)
        attn = attn.reshape(attn.shape[0],attn.shape[1],-1)

        return attn



def HGRscore3(f, g):
    """HGR-style dissimilarity between two 2-D sample matrices (one sample per row).

    Two terms are combined:

    * ``corr`` -- mean cosine similarity between matching rows of ``f`` and ``g``;
    * ``tra``  -- agreement of the two row-wise Gram matrices. Each Gram matrix
      has its strictly-upper triangle folded into an ``(n-1) x (n-1)`` matrix,
      each folded row is L2-normalised, and the summed element-wise product is
      averaged over ``n - 1``.

    Returns:
        ``(score, corr, tra)`` as 0-dim tensors, with
        ``score = (1 - corr - tra / 2) / 1.5``.
    """
    normalize = torch.nn.functional.normalize

    def fold_gram(unit_rows):
        # Gram matrix of the rows; keep its strictly-upper triangle shifted
        # into an (n-1) x (n-1) block, mirror the upper part below the
        # diagonal, then L2-normalise every row.
        gram = torch.mm(unit_rows, torch.t(unit_rows))
        size = len(gram)
        upper = torch.triu(gram, diagonal=1)[0:size - 1, 1:size]
        mirrored = upper + torch.triu(upper, diagonal=1).transpose(0, 1)
        return normalize(mirrored, dim=1), size

    f_unit = normalize(f, dim=1)
    g_unit = normalize(g, dim=1)
    corr = torch.sum(torch.sum(f_unit * g_unit, 1)) / len(f)

    f_fold, size = fold_gram(f_unit)
    g_fold, _ = fold_gram(g_unit)
    tra = torch.sum(torch.sum(f_fold * g_fold, 1)) / (size - torch.tensor(1))

    score = (torch.tensor(1.0) - corr - tra / 2) / 1.5
    return score, corr, tra


if __name__ == '__main__':
    # Smoke test: embed a batch of images into patch tokens, then run the
    # block-wise attention on those tokens.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embed_dim = 64
    x = torch.randn(8, 3, 224, 224).to(device)
    # Attention builds its qkv/proj layers for ``dim + 1`` channels, so pass
    # ``embed_dim - 1`` to match the 64-channel patch tokens.
    attention = Attention(dim=embed_dim - 1).to(device)
    Patch_Embed = PatchEmbed(embed_dim=embed_dim).to(device)

    with torch.no_grad():
        x = Patch_Embed(x)
        out = attention(x)
    print(out)

